{ "cells": [ { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import math, os" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('Groceries.csv')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "commodity_list = df['items'].values" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "for i in range(len(commodity_list)):\n", " commodity_list[i] = set(commodity_list[i][1 : -1].split(','))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# 创建初始的每一个物品单独作为项集的项集列表\n", "def createInit(data):\n", " st = []\n", " for commodity_set in data:\n", " for item in commodity_set:\n", " if [item] not in st:\n", " st.append([item])\n", " st.sort()\n", " return list(map(frozenset, st))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# 扫描数据集,过滤支持度小于阈值的项集\n", "def scanDataset(data, commodity_set, min_support):\n", " cnt = {}\n", " # 对项集在数据集行中出现次数进行计数\n", " for commodity_list in data:\n", " for line in commodity_set:\n", " if line.issubset(commodity_list):\n", " if not cnt.__contains__(line):\n", " cnt[line] = 1\n", " else:\n", " cnt[line] += 1\n", " sum = np.float(len(data))\n", " eligible_set = [] # 合格的项集\n", " support_data = {} # 项集到支持的映射\n", " for key, value in cnt.items():\n", " support = value / sum\n", " if support >= min_support:\n", " eligible_set.append(key)\n", " support_data[key] = support\n", " return eligible_set, support_data" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# 由旧的频繁集生成每个集合物品个数+1的新频繁集\n", "def generateNewSet(old_set, Len):\n", " new_list = []\n", " length = len(old_set)\n", " for i in range(length):\n", " for j in range(i + 1, length):\n", " l1, l2 = sorted(list(old_set[i]))[ : Len - 2], sorted(list(old_set[j]))[ : Len - 2]\n", " if l1 == l2:\n", " new_list.append(old_set[i] | old_set[j])\n", " return new_list" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def apriori(data, min_support):\n", " init_set = createInit(data)\n", " D = list(map(set, data))\n", " eligible_set, support_data = scanDataset(data, init_set, min_support)\n", " L = [eligible_set]\n", " k = 2\n", " while len(L[k - 2]) > 0:\n", " nxt = generateNewSet(L[k - 2], k)\n", " lk, supportData = scanDataset(data, nxt, min_support)\n", " support_data.update(supportData)\n", " L.append(lk)\n", " k += 1\n", " return L, support_data" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# 验证\n", "data = [{1, 3, 4}, {2, 3, 5}, {1, 2, 3, 5}, {2, 5}]\n", "L, support_data = apriori(data, 0.5)\n" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [], "source": [ "def calConfidence(freq_set, h, support_data, rule_map, min_conf=0.5):\n", " rules = []\n", " for seq in h:\n", " if 'whole milk' in seq and 'root vegetables' in freq_set - seq and 'butter' in freq_set - seq:\n", " print(freq_set - seq, ' ', seq)\n", " print(support_data[freq_set], ' ', support_data[freq_set - seq]) \n", " confidence = support_data[freq_set] / support_data[freq_set - seq]\n", " print(freq_set - seq, ' --> ', seq, ' ', confidence)\n", " if confidence >= min_conf:\n", " rule_map.append((freq_set - seq, seq, confidence))\n", " rules.append(seq)\n", " return rules" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "def rulesFromConsequnce(freq_set, h, support_data, rule_map, min_conf=0.5):\n", " length = len(h[0]) # h是可以放在规则右边的词汇集合\n", " if len(freq_set) > length + 1:\n", " nxt = generateNewSet(h, length + 1)\n", " nxt = calConfidence(freq_set, nxt, support_data, rule_map, min_conf)\n", " if len(nxt) > 1:\n", " rulesFromConsequnce(freq_set, nxt, support_data, rule_map, min_conf)" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "def generateRules(L, support_data, min_conf=0.5):\n", " total_rules = []\n", " for i in range(1, len(L)):\n", " for freq_set in L[i]:\n", " h = [frozenset([item]) for item in freq_set]\n", " if i > 1:\n", " rulesFromConsequnce(freq_set, h, support_data, total_rules, min_conf)\n", " else:\n", " calConfidence(freq_set, h, support_data, total_rules, min_conf)\n", " return total_rules" ] }, { "cell_type": "code", "execution_count": 140, "metadata": {}, "outputs": [], "source": [ "def generaterules(L, support_data, min_conf=0.5):\n", " total_rules = []\n", " for i in range(1, 3):\n", " for freq_set in L[i]:\n", " for item in freq_set:\n", " right = frozenset([item])\n", " confidence = support_data[freq_set] / support_data[freq_set - right]\n", " if confidence >= min_conf:\n", " total_rules.append((freq_set - right, right, confidence))\n", " return total_rules" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "L, support_data = apriori(commodity_list, 0.005)" ] }, { "cell_type": "code", "execution_count": 141, "metadata": {}, "outputs": [], "source": [ "ans = generaterules(L, support_data)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "ans_map = {}\n", "for i in range(1, 4):\n", " ans_map[i] = []\n", "for key, val in support_data.items():\n", " if len(key) <= 3:\n", " ans_map[len(key)].append((key, val))" ] }, { "cell_type": "code", "execution_count": 147, "metadata": {}, "outputs": [], "source": [ "with open('ans.txt', 'w') as w:\n", " for i in range(1, 4):\n", " w.write('-------------L' + str(i) + '-------------' + '\\n')\n", " for tp in ans_map[i]:\n", " w.write('[')\n", " for word in tp[0]:\n", " w.write(word + ', ')\n", " w.write('] ')\n", " w.write(str(tp[1]) + '\\n')\n", " # w.write(str(tp[0]) + ' ' + str(tp[1]) + '\\n'\n", " w.write('total: ' + str(len(ans_map[i])) + '\\n')\n", " for grp in ans:\n", " w.write('[')\n", " for word in grp[0]:\n", " w.write(word + ',')\n", " w.write('] --> ')\n", " for word in grp[1]:\n", " w.write('[' + word + ']')\n", " w.write(' conf: ' + str(grp[2]) + '\\n')\n", " w.write(str(len(ans)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "73e8fb81fc9d21637ba62ed4f9412d39843bbeeb61edb8163afd2f9314d52c65" }, "kernelspec": { "display_name": "Python 3.7.6 64-bit (system)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }